#!/usr/bin/env python3
"""
summarize_counts.py

Aggregate per-seed outputs (flip-count or flip-rate CSVs, one file per seed) into:
  - --out      e.g. results/flip_rates_by_context.csv   (MEAN rates across seeds per n, row-normalized)
  - --summary  e.g. results/flip_counts_summary.csv     (if count files were given -> counts summed per n; if only rate files -> copy of the aggregated rates)
Also writes:
  - --metrics_out (default: results/stability_metrics.json)
"""
import argparse, json
from pathlib import Path
import numpy as np
import pandas as pd

# Off-diagonal transition rates whose spread across seeds is the stability metric.
CORE = [
    "rate_IN_to_CS",
    "rate_CS_to_ON",
    "rate_ON_to_CS",
    "rate_CS_to_IN",
]
# Per-source row-sum columns; the residual metric measures their distance from 1.0.
ROWS = [f"rowsum_{src}" for src in ("IN", "CS", "ON")]
# Self-transition (diagonal) rate columns.
STAY = [f"rate_{src}_to_{src}" for src in ("IN", "CS", "ON")]

def counts_to_rates(df: pd.DataFrame) -> pd.DataFrame:
    """Convert per-seed transition counts into row-normalized transition rates.

    Each ``count_<SRC>_to_<DST>`` is divided by the total outgoing count of
    ``SRC``. The denominator is floored at a tiny epsilon, so a source with no
    outgoing transitions yields rates of 0 (not NaN/inf) and a rowsum of 0.

    Args:
        df: Frame with column ``n`` and the ``count_*`` columns read below.
            A ``seed`` column is optional; when present it is carried through
            as the first output column.

    Returns:
        Frame with (``seed``,) ``n``, the seven ``rate_*`` columns and the
        three ``rowsum_*`` columns. A rowsum is 1.0 for a source with at least
        one outgoing transition, else 0.0.
    """
    eps = 1e-12
    out_IN = df["count_IN_to_CS"] + df["count_IN_to_IN"]
    out_CS = df["count_CS_to_ON"] + df["count_CS_to_IN"] + df["count_CS_to_CS"]
    out_ON = df["count_ON_to_CS"] + df["count_ON_to_ON"]
    den_IN = np.maximum(eps, out_IN)
    den_CS = np.maximum(eps, out_CS)
    den_ON = np.maximum(eps, out_ON)

    columns = {}
    # main() tolerates inputs without a seed column, so don't hard-require it here.
    if "seed" in df.columns:
        columns["seed"] = df["seed"]
    columns["n"] = df["n"]
    columns.update({
        "rate_IN_to_CS": df["count_IN_to_CS"] / den_IN,
        "rate_CS_to_ON": df["count_CS_to_ON"] / den_CS,
        "rate_ON_to_CS": df["count_ON_to_CS"] / den_ON,
        "rate_CS_to_IN": df["count_CS_to_IN"] / den_CS,
        "rate_IN_to_IN": df["count_IN_to_IN"] / den_IN,
        "rate_CS_to_CS": df["count_CS_to_CS"] / den_CS,
        "rate_ON_to_ON": df["count_ON_to_ON"] / den_ON,
        # Numerator equals the outgoing total, so this is 1.0 (or 0.0 if empty).
        "rowsum_IN": out_IN / den_IN,
        "rowsum_CS": out_CS / den_CS,
        "rowsum_ON": out_ON / den_ON,
    })
    return pd.DataFrame(columns)

def _max_seed_delta(all_rates: pd.DataFrame) -> float:
    """Return the largest (max - min) spread of any CORE rate among rows sharing an ``n``.

    Rows with the same ``n`` come from different input files (presumably
    different seeds). A group with a single row has spread 0. NaN entries are
    skipped; if no finite spread exists the result is 0.0.
    """
    grouped = all_rates.groupby("n")[CORE]
    spread = (grouped.max() - grouped.min()).max().max()
    return 0.0 if pd.isna(spread) else float(spread)


def _coverage(seen, coverage_from: Path):
    """Compare the observed ``n`` values with the ``n`` column of *coverage_from*.

    Args:
        seen: Sorted list of observed ``n`` values (as floats).
        coverage_from: CSV that may contain an ``n`` column.

    Returns:
        ``(expected, missing, extra, ok)``. ``expected`` is None (and
        missing/extra are empty, ok is True) when the file is absent or has no
        ``n`` column.
    """
    if not coverage_from.exists():
        return None, [], [], True
    dfwant = pd.read_csv(coverage_from)
    if "n" not in dfwant.columns:
        return None, [], [], True
    want = sorted(float(x) for x in dfwant["n"].unique())
    want_set, have = set(want), set(seen)
    missing = [x for x in want if x not in have]
    extra = [x for x in seen if x not in want_set]
    return want, missing, extra, not missing


def _renormalize(mean: pd.DataFrame, eps: float = 1e-12) -> pd.DataFrame:
    """Rescale each source's outgoing rates so they sum to 1.0 in every row (one row per n).

    A source whose outgoing rates sum to <= eps (or to NaN) is left unscaled.
    Its ``rowsum_<SRC>`` column then records that actual sum instead of
    claiming 1.0. Column dtypes (e.g. an integer ``n``) are preserved.
    """
    out = mean.copy()
    for src, cols in (("IN", ["rate_IN_to_IN", "rate_IN_to_CS"]),
                      ("CS", ["rate_CS_to_IN", "rate_CS_to_CS", "rate_CS_to_ON"]),
                      ("ON", ["rate_ON_to_ON", "rate_ON_to_CS"])):
        # skipna=False: a missing component must not be silently treated as 0.
        total = out[cols].sum(axis=1, skipna=False)
        ok = total > eps  # False for NaN totals as well
        out.loc[ok, cols] = out.loc[ok, cols].div(total[ok], axis=0)
        out[f"rowsum_{src}"] = np.where(ok, 1.0, total)
    return out


def main():
    """CLI entry point: aggregate per-seed count/rate CSVs and write rates, summary and metrics.

    Input files are classified by content. A file with a ``rate_IN_to_CS``
    column is treated as rates; any other file is treated as counts and
    converted with ``counts_to_rates``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--in", dest="inputs", nargs="+", required=True,
                    help="paths to per-seed flip_counts_summary.csv OR flip_rates_by_context.csv")
    ap.add_argument("--out", required=True, help="flip_rates_by_context.csv (aggregated mean, renormalized)")
    ap.add_argument("--summary", required=True, help="flip_counts_summary.csv (summed counts if counts given; else copy of agg rates)")
    ap.add_argument("--metrics_out", default="results/stability_metrics.json")
    ap.add_argument("--coverage_from", default="D_values.csv", help="file (with column n) to check coverage against")
    args = ap.parse_args()

    frames_rates, frames_counts = [], []
    for p in args.inputs:
        df = pd.read_csv(p)
        if "rate_IN_to_CS" in df.columns:
            frames_rates.append(df)
        else:
            frames_counts.append(df)
            frames_rates.append(counts_to_rates(df))

    all_rates = pd.concat(frames_rates, ignore_index=True)

    # Stability: max absolute spread across seeds of the core rates, per n.
    max_delta = _max_seed_delta(all_rates)

    # Normalization residuals: how far the input row sums are from 1.
    resid = np.abs(all_rates[ROWS].to_numpy() - 1.0)
    max_resid = float(np.max(resid)) if resid.size else 0.0

    # Coverage of n values versus --coverage_from (only if that file exists).
    cov_list = sorted(float(x) for x in all_rates["n"].unique())
    want, cov_missing, cov_extra, cov_ok = _coverage(cov_list, Path(args.coverage_from))

    # Aggregate: mean across seeds per n, then renormalize each source row.
    agg = _renormalize(all_rates.groupby("n").mean(numeric_only=True).reset_index())

    # Write outputs (create every destination directory, including the metrics one).
    for dest in (args.out, args.summary, args.metrics_out):
        Path(dest).parent.mkdir(parents=True, exist_ok=True)
    agg_out = agg[["n","rate_IN_to_CS","rate_CS_to_ON","rate_ON_to_CS","rate_CS_to_IN",
                   "rate_IN_to_IN","rate_CS_to_CS","rate_ON_to_ON","rowsum_IN","rowsum_CS","rowsum_ON"]].copy()
    agg_out.to_csv(args.out, index=False)

    if frames_counts:
        # Summed counts across seeds. Only count-type inputs contribute; rate
        # files passed alongside them are not reflected here. Summing seed ids
        # is meaningless, so the seed column is dropped first.
        cnt = pd.concat(frames_counts, ignore_index=True)
        sums = (cnt.drop(columns="seed", errors="ignore")
                   .groupby("n").sum(numeric_only=True).reset_index())
        sums.to_csv(args.summary, index=False)
    else:
        # No counts given; keep a copy of aggregated rates as the "summary"
        agg_out.to_csv(args.summary, index=False)

    metrics = {
        "max_delta_rate": float(max_delta),
        "max_rowsum_residual": float(max_resid),
        "coverage_n": cov_list,
        "coverage_expected": want,
        "coverage_missing": cov_missing,
        "coverage_extra": cov_extra,
        "coverage_ok": bool(cov_ok),
        "num_seeds": int(all_rates["seed"].nunique()) if "seed" in all_rates.columns else 1
    }
    with open(args.metrics_out, "w", encoding="utf-8") as f:
        json.dump(metrics, f, indent=2)

    print(f"max Δrate across seeds = {max_delta:.4f}")
    print(f"max |row-sum−1| residual = {max_resid:.3e}")
    if want is not None:
        print(f"coverage expected = {want}")
    print(f"coverage seen = {cov_list}")
    print(f"missing = {cov_missing}; extra = {cov_extra}")
    print(f"wrote {args.out} and {args.summary}")
    print(f"metrics → {args.metrics_out}")

# Run the CLI only when executed as a script, not when the module is imported.
if __name__ == "__main__":
    main()
